import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from pppicodet.bbox.utils import batch_distance2bbox, bbox_center, bbox2distance
from ..assigner import ATSSAssigner, TaskAlignedAssigner
from .network_blocks import ConvBNLayer, get_activation, ConvNormLayer, DGQP, DPModule, Integral
from .losses import VarifocalLoss, FocalLoss, BboxLoss, DistributionFocalLoss, GIoULoss

eps = 1e-9

def get_static_shape(tensor):
    shape = torch.shape(tensor)
    shape.stop_gradient = True
    return shape

class PicoSE(nn.Layer):
    def __init__(self, feat_channels):
        super(PicoSE, self).__init__()
        self.fc = nn.Conv2d(feat_channels, feat_channels, 1)
        self.conv = ConvNormLayer(feat_channels, feat_channels, 1, 1)

        self._init_weights()

    def _init_weights(self):
        # normal_(self.fc.weight, std=0.001)
        nn.init.normal_(self.fc.weight, std=0.001)

    def forward(self, feat, avg_feat):
        weight = F.sigmoid(self.fc(avg_feat))
        out = self.conv(feat * weight)
        return out

class PicoFeat(nn.Module):
    """
    PicoFeat of PicoDet
    Args:
        feat_in (int): The channel number of input Tensor.
        feat_out (int): The channel number of output Tensor.
        num_convs (int): The convolution number of the LiteGFLFeat.
        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
        share_cls_reg (bool): Whether to share the cls and reg output.
        act (str): The act of per layers.
        use_se (bool): Whether to use se module.
    """

    def __init__(self,
                 feat_in=256,
                 feat_out=96,
                 num_fpn_stride=3,
                 num_convs=2,
                 norm_type='bn',
                 share_cls_reg=False,
                 act='hardswish',
                 use_se=False):
        super(PicoFeat, self).__init__()
        self.num_convs = num_convs
        self.norm_type = norm_type
        self.share_cls_reg = share_cls_reg
        self.act = act
        self.use_se = use_se
        self.cls_convs = []
        self.reg_convs = []
        if use_se:
            assert share_cls_reg == True, \
                'In the case of using se, share_cls_reg is not supported'
            self.se = nn.ModuleList()
        for stage_idx in range(num_fpn_stride):
            cls_subnet_convs = []
            reg_subnet_convs = []
            for i in range(self.num_convs):
                in_c = feat_in if i == 0 else feat_out
                cls_conv_dw = ConvNormLayer(
                        ch_in=in_c,
                        ch_out=feat_out,
                        filter_size=5,
                        stride=1,
                        groups=feat_out,
                        norm_type=norm_type,
                        bias_on=False,
                        lr_scale=2.)
                cls_subnet_convs.append(cls_conv_dw)
                cls_conv_pw = ConvNormLayer(
                        ch_in=in_c,
                        ch_out=feat_out,
                        filter_size=1,
                        stride=1,
                        norm_type=norm_type,
                        bias_on=False,
                        lr_scale=2.)
                cls_subnet_convs.append(cls_conv_pw)

                if not self.share_cls_reg:
                    reg_conv_dw = ConvNormLayer(
                            ch_in=in_c,
                            ch_out=feat_out,
                            filter_size=5,
                            stride=1,
                            groups=feat_out,
                            norm_type=norm_type,
                            bias_on=False,
                            lr_scale=2.)
                    reg_subnet_convs.append(reg_conv_dw)
                    reg_conv_pw = ConvNormLayer(
                            ch_in=in_c,
                            ch_out=feat_out,
                            filter_size=1,
                            stride=1,
                            norm_type=norm_type,
                            bias_on=False,
                            lr_scale=2.)
                    reg_subnet_convs.append(reg_conv_pw)
            self.cls_convs.append(cls_subnet_convs)
            self.reg_convs.append(reg_subnet_convs)
            if use_se:
                self.se.append(PicoSE(feat_out))

    def act_func(self, x):
        if self.act == "lrelu":
            x = get_activation(x)
        elif self.act == "hardswish":
            x = get_activation(x)
        return x

    def forward(self, fpn_feat, stage_idx):
        assert stage_idx < len(self.cls_convs)
        cls_feat = fpn_feat
        reg_feat = fpn_feat
        for i in range(len(self.cls_convs[stage_idx])):
            cls_feat = self.act_func(self.cls_convs[stage_idx][i](cls_feat))
            reg_feat = cls_feat
            if not self.share_cls_reg:
                reg_feat = self.act_func(self.reg_convs[stage_idx][i](reg_feat))
        if self.use_se:
            avg_feat = F.adaptive_avg_pool2d(cls_feat, (1, 1))
            se_feat = self.act_func(self.se[stage_idx](cls_feat, avg_feat))
            return cls_feat, se_feat
        return cls_feat, reg_feat

class PicoHead(nn.Module):
    def __init__(self, num_classes=80, fpn_stride=[8, 16, 32, 64], prior_prob=0.01,
                 reg_max=7, feat_in_chan=128, nms=None, nms_pre=1000, cell_offset=0.5):
        super(PicoHead, self).__init__()

        self.conv_feat = PicoFeat()
        self.dgqp_module = DGQP()
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = VarifocalLoss().cuda()
        self.loss_dfl = DistributionFocalLoss().cuda()
        self.loss_bbox = GIoULoss().cuda()
        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.use_sigmoid = self.loss_vfl.use_sigmoid
        if self.use_sigmoid:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)

        self.head_cls_list = []
        self.head_reg_list = []
        for i in range(len(fpn_stride)):
            head_cls = nn.Conv2d(
                            in_channels=self.feat_in_chan,
                            out_channels=self.cls_out_channels + 4 * (self.reg_max + 1)
                            if self.conv_feat.share_cls_reg else self.cls_out_channels,
                            kernel_size=1,
                            stride=1,
                            padding=0,)
            self.head_cls_list.append(head_cls)
            nn.init.normal_(head_cls.weight.data, mean=0., std=0.01)
            nn.init.constant_(head_cls.bias.data, val=bias_init_value)
            if not self.conv_feat.share_cls_reg:
                head_reg = nn.Conv2d(
                            in_channels=self.feat_in_chan,
                            out_channels=4 * (self.reg_max + 1),
                            kernel_size=1,
                            stride=1,
                            padding=0,)
                self.head_reg_list.append(head_reg)
                nn.init.normal_(head_cls.weight.data, mean=0., std=0.01)
                nn.init.constant_(head_cls.bias.data, val=0.)

    def forward(self, fpn_feats, deploy=False):

        assert len(fpn_feats) == len(self.fpn_stride)

        cls_logits_list = []
        bboxes_reg_list = []
        for i ,fpn_feat in enumerate(fpn_feats):
            print("fpn feat shape:", fpn_feat.shape)
            conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat, i)
            if self.conv_feat.share_cls_reg:
                cls_logits = self.head_cls_list[i](conv_cls_feat)
                cls_score, bbox_pred = torch.split(cls_logits, [self.cls_out_channels, 4*(self.reg_max+1)], dim=1)
            else:
                cls_score = self.head_cls_list[i](conv_cls_feat)
                bbox_pred = self.head_reg_list[i](conv_reg_feat)

            if self.dgqp_module:
                quality_score = self.dgqp_module(bbox_pred)
                cls_score = F.sigmoid(cls_score) * quality_score

            if deploy:
                cls_score = F.sigmoid(cls_score).reshape([1, self.cls_out_channels, -1]).transpose([0, 2, 1])
                bbox_pred = bbox_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose([0, 2, 1])

            elif not self.training:
                cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])

            cls_logits_list.append(cls_score)
            bboxes_reg_list.append(bbox_pred)

        return (cls_logits_list, bboxes_reg_list)

class PicoHeadV2(nn.Module):
    """
    PicoHeadV2
    Args:
        conv_feat (object): Instance of 'PicoFeat'
        num_classes (int): Number of classes
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        loss_class (object): Instance of VariFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        assigner (object): Instance of label assigner.
        reg_max: Max value of integral set :math: `{0, ..., reg_max}`
                n QFL setting. Default: 7.
    """
    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
        'static_assigner', 'assigner', 'nms'
    ]
    __shared__ = ['num_classes', 'eval_size']

    def __init__(self,
                 conv_feat='PicoFeatV2',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32],
                 prior_prob=0.01,
                 use_align_head=True,
                 loss_class='VariFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 static_assigner_epoch=60,
                 static_assigner='ATSSAssigner',
                 assigner='TaskAlignedAssigner',
                 reg_max=16,
                 feat_in_chan=96,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0,
                 act='hardswish',
                 grid_cell_scale=5.0,
                 eval_size=None):
        super(PicoHeadV2, self).__init__(
            conv_feat=conv_feat,
            dgqp_module=dgqp_module,
            num_classes=num_classes,
            fpn_stride=fpn_stride,
            prior_prob=prior_prob,
            loss_class=loss_class,
            loss_dfl=loss_dfl,
            loss_bbox=loss_bbox,
            reg_max=reg_max,
            feat_in_chan=feat_in_chan,
            nms=nms,
            nms_pre=nms_pre,
            cell_offset=cell_offset, )
        self.conv_feat = conv_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_vfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox

        self.static_assigner_epoch = static_assigner_epoch
        self.static_assigner = static_assigner
        self.assigner = assigner

        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.act = act
        self.grid_cell_scale = grid_cell_scale
        self.use_align_head = use_align_head
        self.cls_out_channels = self.num_classes
        self.eval_size = eval_size

        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        # Clear the super class initialization
        self.gfl_head_cls = None
        self.gfl_head_reg = None
        self.scales_regs = None

        self.distribution_project = Integral(self.reg_max)

        self.head_cls_list = []
        self.head_reg_list = []
        self.cls_align = nn.ModuleList()

        for i in range(len(fpn_stride)):
            head_cls = nn.Conv2d(
                    in_channels=self.feat_in_chan,
                    out_channels=self.cls_out_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0
            )
            self.head_cls_list.append(head_cls)
            head_reg = nn.Conv2d(
                    in_channels=self.feat_in_chan,
                    out_channels=4 * (self.reg_max + 1),
                    kernel_size=1,
                    stride=1,
                    padding=0
            )
            self.head_reg_list.append(head_reg)
            if self.use_align_head:
                self.cls_align.append(
                    DPModule(
                        self.feat_in_chan,
                        1,
                        5,
                        act=self.act,
                        use_act_in_out=False))

        # initialize the anchor points
        if self.eval_size:
            self.anchor_points, self.stride_tensor = self._generate_anchors()

    def forward(self, feats, export_post_process=True):
        assert len(feats) == len(self.fpn_strides), \
            "The size of feats is not equal to size of fpn_strides"

        if self.training:
            return self.forward_train(feats)
        else:
            return self.forward_eval(feats, export_post_process)

    def forward_train(self, feats):
        cls_score_list, reg_list, box_list = [], [], []
        for i, (fpn_feat, stride) in enumerate(zip(feats, self.fpn_stride)):
            b, _, h, w = get_static_shape(fpn_feat)
            # task decomposition
            conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i)
            cls_logit = self.head_cls_list[i](se_feat)
            reg_pred = self.head_reg_list[i](se_feat)

            # cls prediction and alignment
            if self.use_align_head:
                cls_prob = F.sigmoid(self.cls_align[i](conv_cls_feat))
                cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
            else:
                cls_score = F.sigmoid(cls_logit)

            cls_score_out = cls_score.transpose([0, 2, 3, 1])
            bbox_pred = reg_pred.transpose([0, 2, 3, 1])
            b, cell_h, cell_w, _ = torch.shape(cls_score_out)
            y, x = self.get_single_level_center_point(
                [cell_h, cell_w], stride, cell_offset=self.cell_offset)
            center_points = torch.stack([x, y], axis=-1)
            cls_score_out = cls_score_out.reshape(
                [b, -1, self.cls_out_channels])
            bbox_pred = self.distribution_project(bbox_pred) * stride
            bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])
            bbox_pred = batch_distance2bbox(
                center_points, bbox_pred, max_shapes=None)
            cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
            reg_list.append(reg_pred.flatten(2).transpose([0, 2, 1]))
            box_list.append(bbox_pred / stride)

        cls_score_list = torch.cat(cls_score_list, axis=1)
        box_list = torch.cat(box_list, axis=1)
        reg_list = torch.cat(reg_list, axis=1)
        return cls_score_list, reg_list, box_list, feats

    def forward_eval(self, feats, export_post_process):
        if self.eval_size:
            anchor_points, stride_tensor = self.anchor_points, self.stride_tensor
        else:
            anchor_points, stride_tensor = self._generate_anchors(feats)
        cls_score_list, box_list = [], []
        for i, (fpn_feat, stride) in enumerate(zip(feats, self.fpn_stride)):
            b, _, h, w = fpn_feat.shape
            # task decomposition
            conv_cls_feat, se_feat = self.conv_feat(fpn_feat, i)
            cls_logit = self.head_cls_list[i](se_feat)
            reg_pred = self.head_reg_list[i](se_feat)

            # cls prediction and alignment
            if self.use_align_head:
                cls_prob = F.sigmoid(self.cls_align[i](conv_cls_feat))
                cls_score = (F.sigmoid(cls_logit) * cls_prob + eps).sqrt()
            else:
                cls_score = F.sigmoid(cls_logit)

            if not export_post_process:
                # Now only supports batch size = 1 in deploy
                cls_score_list.append(
                    cls_score.reshape([1, self.cls_out_channels, -1]).transpose(
                        [0, 2, 1]))
                box_list.append(
                    reg_pred.reshape([1, (self.reg_max + 1) * 4, -1]).transpose(
                        [0, 2, 1]))
            else:
                l = h * w
                cls_score_out = cls_score.reshape([b, self.cls_out_channels, l])
                bbox_pred = reg_pred.transpose([0, 2, 3, 1])
                bbox_pred = self.distribution_project(bbox_pred)
                bbox_pred = bbox_pred.reshape([b, l, 4])
                cls_score_list.append(cls_score_out)
                box_list.append(bbox_pred)

        if export_post_process:
            cls_score_list = torch.cat(cls_score_list, axis=-1)
            box_list = torch.cat(box_list, axis=1)
            box_list = batch_distance2bbox(anchor_points, box_list)
            box_list *= stride_tensor

        return cls_score_list, box_list
    
    def get_loss(self, head_outs, gt_meta):
        pred_scores, pred_regs, pred_bboxes, fpn_feats = head_outs
        gt_labels = gt_meta['gt_class']
        gt_bboxes = gt_meta['gt_bbox']
        gt_scores = gt_meta['gt_score'] if 'gt_score' in gt_meta else None
        num_imgs = gt_meta['im_id'].shape[0]
        pad_gt_mask = gt_meta['pad_gt_mask']

        anchors, _, num_anchors_list, stride_tensor_list = self.generate_anchors_for_grid_cell(
            fpn_feats, self.fpn_stride, self.grid_cell_scale, self.cell_offset)

        centers = bbox_center(anchors)

        # label assignment
        if gt_meta['epoch_id'] < self.static_assigner_epoch:
            assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
                anchors,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores,
                pred_bboxes=pred_bboxes.detach() * stride_tensor_list)

        else:
            assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
                pred_scores.detach(),
                pred_bboxes.detach() * stride_tensor_list,
                centers,
                num_anchors_list,
                gt_labels,
                gt_bboxes,
                pad_gt_mask,
                bg_index=self.num_classes,
                gt_scores=gt_scores)

        assigned_bboxes /= stride_tensor_list

        centers_shape = centers.shape
        flatten_centers = centers.expand(
            [num_imgs, centers_shape[0], centers_shape[1]]).reshape([-1, 2])
        flatten_strides = stride_tensor_list.expand(
            [num_imgs, centers_shape[0], 1]).reshape([-1, 1])
        flatten_cls_preds = pred_scores.reshape([-1, self.num_classes])
        flatten_regs = pred_regs.reshape([-1, 4 * (self.reg_max + 1)])
        flatten_bboxes = pred_bboxes.reshape([-1, 4])
        flatten_bbox_targets = assigned_bboxes.reshape([-1, 4])
        flatten_labels = assigned_labels.reshape([-1])
        flatten_assigned_scores = assigned_scores.reshape(
            [-1, self.num_classes])

        pos_inds = torch.nonzero(
            torch.logical_and((flatten_labels >= 0),
                               (flatten_labels < self.num_classes)),
            as_tuple=False).squeeze(1)

        num_total_pos = len(pos_inds)

        if num_total_pos > 0:
            pos_bbox_targets = torch.gather(
                flatten_bbox_targets, pos_inds, axis=0)
            pos_decode_bbox_pred = torch.gather(
                flatten_bboxes, pos_inds, axis=0)
            pos_reg = torch.gather(flatten_regs, pos_inds, axis=0)
            pos_strides = torch.gather(flatten_strides, pos_inds, axis=0)
            pos_centers = torch.gather(
                flatten_centers, pos_inds, axis=0) / pos_strides

            weight_targets = flatten_assigned_scores.detach()
            weight_targets = torch.gather(
                weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)

            pred_corners = pos_reg.reshape([-1, self.reg_max + 1])
            target_corners = bbox2distance(pos_centers, pos_bbox_targets,
                                           self.reg_max).reshape([-1])
            # regression loss
            loss_bbox = torch.sum(
                self.loss_bbox(pos_decode_bbox_pred,
                               pos_bbox_targets) * weight_targets)

            # dfl loss
            loss_dfl = self.loss_dfl(
                pred_corners,
                target_corners,
                weight=weight_targets.expand([-1, 4]).reshape([-1]),
                avg_factor=4.0)
        else:
            loss_bbox = torch.zeros([1])
            loss_dfl = torch.zeros([1])

        avg_factor = flatten_assigned_scores.sum()
        '''
        if paddle.fluid.core.is_compiled_with_dist(
        ) and parallel_helper._is_parallel_ctx_initialized():
            paddle.distributed.all_reduce(avg_factor)
            avg_factor = paddle.clip(
                avg_factor / paddle.distributed.get_world_size(), min=1)
        '''
        loss_vfl = self.loss_vfl(
            flatten_cls_preds, flatten_assigned_scores, avg_factor=avg_factor)

        loss_bbox = loss_bbox / avg_factor
        loss_dfl = loss_dfl / avg_factor

        loss_states = dict(
            loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)

        return loss_states

    def _generate_anchors(self, feats=None, device='cuda:0'):
        # just use in eval time
        anchor_points = []
        stride_tensor = []
        for i, stride in enumerate(self.fpn_strides):
            if feats is not None:
                _, _, h, w = feats[i].shape
            else:
                h = int(self.eval_input_size[0] / stride)
                w = int(self.eval_input_size[1] / stride)
            shift_x = torch.arange(end=w, device=device) + self.grid_cell_offset
            shift_y = torch.arange(end=h, device=device) + self.grid_cell_offset
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x)
            anchor_point = torch.stack(
                    [shift_x, shift_y], axis=-1).to(torch.float)
            # anchor_point = paddle.cast(
            #     paddle.stack(
            #         [shift_x, shift_y], axis=-1), dtype='float32')
            anchor_points.append(anchor_point.reshape([-1, 2]))
            stride_tensor.append(
                torch.full(
                    (h * w, 1), stride, dtype=torch.float, device=device))
        anchor_points = torch.cat(anchor_points)
        stride_tensor = torch.cat(stride_tensor)
        return anchor_points, stride_tensor

    def get_single_level_center_point(self, featmap_size, stride,
                                      cell_offset=0):
        """
        Generate pixel centers of a single stage feature map.
        Args:
            featmap_size: height and width of the feature map
            stride: down sample stride of the feature map
        Returns:
            y and x of the center points
        """
        h, w = featmap_size
        x_range = (torch.arange(w, dtype='float32') + cell_offset) * stride
        y_range = (torch.arange(h, dtype='float32') + cell_offset) * stride
        y, x = torch.meshgrid(y_range, x_range)
        y = y.flatten()
        x = x.flatten()
        return y, x

    def post_process(self, head_outs, scale_factor, export_nms=True):
        pred_scores, pred_bboxes = head_outs
        if not export_nms:
            return pred_bboxes, pred_scores
        else:
            # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
            scale_y, scale_x = torch.split(scale_factor, 2, axis=-1)
            scale_factor = torch.cat(
                [scale_x, scale_y, scale_x, scale_y],
                axis=-1).reshape([-1, 1, 4])
            # scale bbox to origin image size.
            pred_bboxes /= scale_factor
            bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
            return bbox_pred, bbox_num
    
    def generate_anchors_for_grid_cell(self, feats, fpn_strides, grid_cell_size=5.0, grid_cell_offset=0.5,
                                       device='cpu'):
        r"""
        Like ATSS, generate anchors based on grid size.
        Args:
            feats (List[Tensor]): shape[s, (b, c, h, w)]
            fpn_strides (tuple|list): shape[s], stride for each scale feature
            grid_cell_size (float): anchor size
            grid_cell_offset (float): The range is between 0 and 1.
        Returns:
            anchors (Tensor): shape[l, 4], "xmin, ymin, xmax, ymax" format.
            anchor_points (Tensor): shape[l, 2], "x, y" format.
            num_anchors_list (List[int]): shape[s], contains [s_1, s_2, ...].
            stride_tensor (Tensor): shape[l, 1], contains the stride for each scale.
        """
        assert len(feats) == len(fpn_strides)
        anchors = []
        anchor_points = []
        num_anchors_list = []
        stride_tensor = []
        for feat, stride in zip(feats, fpn_strides):
            _, _, h, w = feat.shape
            cell_half_size = grid_cell_size * stride * 0.5
            shift_x = (torch.arange(end=w, device=device) + grid_cell_offset) * stride
            shift_y = (torch.arange(end=h, device=device) + grid_cell_offset) * stride
            shift_y, shift_x = torch.meshgrid(shift_y, shift_x)
            anchor = torch.stack(
                [
                    shift_x - cell_half_size, shift_y - cell_half_size,
                    shift_x + cell_half_size, shift_y + cell_half_size
                ],
                axis=-1).clone().to(feat.dtype)
            anchor_point = torch.stack(
                [shift_x, shift_y], axis=-1).clone().to(feat.dtype)

            anchors.append(anchor.reshape([-1, 4]))
            anchor_points.append(anchor_point.reshape([-1, 2]))
            num_anchors_list.append(len(anchors[-1]))
            stride_tensor.append(
                torch.full(
                    [num_anchors_list[-1], 1], stride, dtype=feat.dtype))
        anchors = torch.cat(anchors)
        anchor_points = torch.cat(anchor_points).cuda()
        stride_tensor = torch.cat(stride_tensor).cuda()
        return anchors, anchor_points, num_anchors_list, stride_tensor